上章講完prioritized遇到的挑戰跟解決方案,這章節就開始講實作囉!
SumTree是種二元節點儲存方式,從上的根節點直到下面底部的節點,節點的值都是由下方的值網上組起來的。我們先來描述如何從樹的結構去找尋想要的值過程:
類別初始化。
def __init__(self,capacity):
self.capacity = capacity
self.data_pointer = 0
self.tree = np.zeros(2 * capacity - 1)
self.data = np.zeros(capacity,dtype=object)
@property
def total_p(self):
return self.tree[0] # the root
新增節點的順序是從0開始更新,直至終點再從0開始。
def add(self,p,data):
tree_idx = self.data_pointer + self.capacity - 1 # 節點index
self.data[self.data_pointer] = data # 底部index賦予值
self.update(tree_idx,p)
self.data_pointer += 1
if self.data_pointer >= self.capacity: # replace when exceed the capacity
self.data_pointer = 0
之前提過上面節點的value都是基於最下面的節點,所以一旦有新的值更新,上面的父節點也會對差值做出改變。
def update(self,tree_idx,p):
change = p - self.tree[tree_idx]
self.tree[tree_idx] = p
while tree_idx!=0:
tree_idx = (tree_idx - 1) // 2
self.tree[tree_idx] += change
最後這邊要實作從頭開始找值的方法。
def get_leaf(self,v):
parent_idx = 0
while True:
cl_idx = 2 * parent_idx + 1
cr_idx = cl_idx + 1
if cl_idx >= len(self.tree):
leaf_idx = parent_idx
break
else:
if v <= self.tree[cl_idx]:
parent_idx = cl_idx
else:
v -= self.tree[cl_idx]
parent_idx = cr_idx
data_idx = leaf_idx - self.capacity + 1
return leaf_idx,self.tree[leaf_idx],self.data[data_idx]
樹狀結構根找值我們介紹到這邊,下章接著講怎跟整個訓練做配合的,我們明天見拉~
莫凡RL程式碼參考:https://bre.is/tCA5GuPc